#!/usr/bin/env python3
from __future__ import annotations

import argparse
import re
import json
import os
import sys
import time
import shortuuid
import concurrent.futures


# Directory holding this script, and the CoThinker repository root, which is
# taken to be two directory levels above it (made absolute via abspath()).
SCRIPT_DIR = os.path.dirname(__file__)
REPO_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))

# Require the upstream LiveBench sibling repo for imports and data outputs.
# It is looked up as "<parent of REPO_ROOT>/LiveBench" and must contain
# livebench/common.py; otherwise an error is printed to stderr and the
# process exits with status 1 (this check runs at import time).
UPSTREAM_LIVEBENCH_ROOT = os.path.abspath(os.path.join(REPO_ROOT, os.pardir, "LiveBench"))
if not (os.path.isdir(UPSTREAM_LIVEBENCH_ROOT) and os.path.exists(os.path.join(UPSTREAM_LIVEBENCH_ROOT, "livebench", "common.py"))):
    sys.stderr.write(
        "[CoThinker] ERROR: Upstream LiveBench not found.\n"
        f"Expected at: {UPSTREAM_LIVEBENCH_ROOT}\n"
        "Please clone/install LiveBench as a sibling directory and run again.\n"
    )
    sys.exit(1)
LIVEBENCH_ROOT = UPSTREAM_LIVEBENCH_ROOT

# Ensure the CoThinker repo root and the LiveBench root are importable.
# Each root is prepended only when it is not already on sys.path; when both
# are inserted, the LiveBench root ends up ahead of the CoThinker root
# because it is inserted last at index 0.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
if LIVEBENCH_ROOT not in sys.path:
    sys.path.insert(0, LIVEBENCH_ROOT)

from cothinker.engine import CoThinkerConfig, PromptRegistry, CoThinkerEngine
from cothinker.llm_client import OpenAIStyleClient


def _load_dotenv_if_exists(path: str) -> None:
    """Load ``KEY=VALUE`` pairs from a dotenv-style file into ``os.environ``.

    This is a best-effort helper: a missing file is ignored silently, and a
    file that cannot be read or decoded only produces a warning on stderr.
    It never raises for I/O problems.

    Parsing rules:
      * blank lines and lines starting with ``#`` are skipped;
      * an optional leading ``export `` (shell style) is dropped;
      * the line is split on the first ``=``; key and value are stripped;
      * a single pair of matching surrounding quotes is removed from the value;
      * entries with an empty key or an empty value are ignored;
      * variables already present in the environment are never overwritten.

    Args:
        path: Location of the dotenv file to read.
    """
    if not os.path.isfile(path):
        return

    try:
        with open(path, "r", encoding="utf-8") as f:
            raw_lines = f.readlines()
    except (OSError, UnicodeError) as exc:
        # Keep the best-effort contract, but do not hide the problem entirely.
        sys.stderr.write(f"[CoThinker] WARN: could not read env file {path}: {exc}\n")
        return

    export_prefix = "export "
    for raw in raw_lines:
        entry = raw.strip()
        if not entry or entry.startswith("#"):
            continue
        if entry.startswith(export_prefix):
            # Accept shell-style "export KEY=value" lines.
            entry = entry[len(export_prefix):].lstrip()
        key, sep, value = entry.partition("=")
        if not sep:
            continue
        key = key.strip()
        value = value.strip()
        # Only strip a *matching* pair of quotes so that values which merely
        # start or end with a quote character are left intact.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        if key and value and key not in os.environ:
            os.environ[key] = value


def load_templates(templates_dir: str | None = None) -> dict[str, str]:
    """Read the CoThinker prompt templates from disk.

    The directory is chosen in this order: the ``templates_dir`` argument,
    the ``COTHINKER_PROMPT_DIR`` environment variable, and finally
    ``cothinker/prompts`` under the repository root.

    Args:
        templates_dir: Optional directory overriding every other source.

    Returns:
        A mapping from template key to the full text of its file.

    Raises:
        OSError: If one of the expected template files cannot be opened.
    """
    dir_from_env = os.environ.get("COTHINKER_PROMPT_DIR")
    if templates_dir:
        base_dir = templates_dir
    elif dir_from_env:
        base_dir = dir_from_env
    else:
        base_dir = os.path.join(REPO_ROOT, "cothinker", "prompts")

    # (template key, file name) pairs, in the order they are loaded.
    wanted = (
        ("style_orchestrator", "style_orchestrator.txt"),
        ("tms", "tms.txt"),
        ("individual_summarizer", "individual_summarizer.txt"),
        ("synthesizer", "synthesizer.txt"),
        ("agent_turn", "agent_turn.txt"),
    )

    loaded: dict[str, str] = {}
    for template_key, file_name in wanted:
        with open(os.path.join(base_dir, file_name), "r", encoding="utf-8") as handle:
            loaded[template_key] = handle.read()
    return loaded


def make_model_display_name(base_model: str, cfg: CoThinkerConfig, template: str = "cothinker") -> str:
    """Build a model id string that encodes the CoThinker configuration.

    Args:
        base_model: Identifier of the underlying chat model.
        cfg: Configuration object; the attributes ``small_world``,
            ``enable_style_generator``, ``rewiring_prob``, ``num_agents``,
            ``num_rounds``, ``num_references`` and ``summarizer`` are read.
        template: ``"compact"`` selects the short layout; any other value
            selects the default ``cothinker-...`` layout.

    Returns:
        The dash-separated display name. The rewiring probability is encoded
        as a rounded integer percentage.
    """
    topology_tag = "SWon" if cfg.small_world else "SWoff"

    if template == "compact":
        style_tag = "SG1" if cfg.enable_style_generator else "SG0"
        rewire_pct = int(round(cfg.rewiring_prob * 100))
        return f"{base_model}-{topology_tag}-A{cfg.num_agents}-R{cfg.num_rounds}-Ref{cfg.num_references}-P{rewire_pct}-SUM{cfg.summarizer}-{style_tag}"

    # Default layout: prefixed with "cothinker" and using more verbose tags.
    segments = ["cothinker", base_model, topology_tag]
    segments.append(f"A{cfg.num_agents}")
    segments.append(f"R{cfg.num_rounds}")
    segments.append(f"Ref{cfg.num_references}")
    segments.append(f"P{int(round(cfg.rewiring_prob * 100))}")
    segments.append(f"SUM_{cfg.summarizer}")
    segments.append("SGon" if cfg.enable_style_generator else "SGoff")
    return "-".join(segments)


def sanitize_model_id(model_id: str) -> str:
    """Return ``model_id`` with every filename-unsafe character turned into ``-``.

    Only ASCII letters, digits and ``_ . + -`` are kept. Everything else,
    including path separators (so no subdirectories get created), is replaced
    one-for-one with a dash.
    """
    unsafe_char = re.compile(r"[^a-zA-Z0-9_.+\-]")
    return unsafe_char.sub("-", model_id)


def _run_and_write(questions, answer_file, model_display_name, engine, args, reorg_answer_file_fn):
    """Answer a batch of questions with the engine and append them to a JSONL file.

    Questions are processed concurrently in a thread pool. Each successful
    question yields one JSON line; the lines are appended to ``answer_file``
    in the order of ``questions`` once every task has finished, and then
    ``reorg_answer_file_fn(answer_file)`` is called. Questions whose
    processing raised are reported on stderr and omitted from the output.

    Args:
        questions: Sequence of question dicts; ``"question_id"`` and
            ``"turns"`` are read from each.
        answer_file: Path of the JSONL file to append to; its parent
            directory is created if needed.
        model_display_name: Written as ``model_id`` in every answer and used
            as the default trace sub-directory name.
        engine: Object exposing ``run(question_dict)`` whose result is indexed
            with ``"final_answer"``; an optional ``cfg`` attribute is dumped
            into traces.
        args: Parsed CLI namespace. ``num_choices``, ``model`` and ``api_base``
            are read directly (and ``trace_dir`` when ``save_trace`` is set);
            ``save_trace``, ``dry_run``, ``parallel`` and the ``seed_*``
            options are read via ``getattr`` with defaults.
        reorg_answer_file_fn: Callable invoked with ``answer_file`` after the
            answers have been written.
    """
    os.makedirs(os.path.dirname(answer_file), exist_ok=True)
    print(f"Questions: {len(questions)} | Output: {answer_file}")

    # Derive task directory and default trace base dir from the answer file path
    # (task_dir is the parent of the directory that contains answer_file).
    task_dir = os.path.dirname(os.path.dirname(answer_file))
    trace_base_dir = None
    if getattr(args, "save_trace", False):
        if args.trace_dir:
            # Relative --trace-dir values are resolved against the LiveBench root.
            trace_base_dir = args.trace_dir if os.path.isabs(args.trace_dir) else os.path.join(LIVEBENCH_ROOT, args.trace_dir)
        else:
            trace_base_dir = os.path.join(task_dir, "_traces", model_display_name)
        os.makedirs(trace_base_dir, exist_ok=True)

    # Resolve optional seed trace dir per task if provided. For relative paths, resolve under the task dir.
    seed_trace_dir = None
    if getattr(args, "seed_initial_from_style_tms", False):
        root = getattr(args, "seed_trace_root", None)
        if root:
            if os.path.isabs(root):
                seed_trace_dir = root
            else:
                # resolve relative to the task directory
                seed_trace_dir = os.path.join(task_dir, root)
        else:
            sys.stderr.write("[CoThinker] WARN: --seed-initial-from-style-tms set but --seed-trace-root not provided; seeding disabled for this run.\n")
            seed_trace_dir = None

    def process_one(q):
        """Produce the JSON-encoded answer record for a single question."""
        try:
            _qid = q.get("question_id")
        except Exception:
            _qid = None
        print(f"[CoThinker] Begin: qid={_qid}, turns={len(q.get('turns', []))}")
        choices = []
        for i in range(args.num_choices):
            turns_out = []
            for j in range(len(q["turns"])):
                if getattr(args, "dry_run", False):
                    # Dry run: no model call, placeholder answer only.
                    result = None
                    turns_out.append("$DUMMY$")
                else:
                    seeded_used = False
                    seeded_answer = None
                    result = None
                    # If seeding enabled, try to load turn-1 final_answer from baseline traces
                    # (only the first turn, j == 0, is ever seeded).
                    if j == 0 and seed_trace_dir is not None:
                        try:
                            q_seed_dir = os.path.join(seed_trace_dir, str(q["question_id"]))
                            choice_idx = int(getattr(args, "seed_choices_index", 0))
                            seed_path = os.path.join(q_seed_dir, f"turn_1_choice_{choice_idx}.json")
                            if os.path.isfile(seed_path):
                                with open(seed_path, "r", encoding="utf-8") as sf:
                                    seed_payload = json.load(sf)
                                # tolerate various shapes; prefer payload["result"]["final_answer"]
                                if isinstance(seed_payload, dict):
                                    maybe = seed_payload.get("result") if "result" in seed_payload else seed_payload
                                    if isinstance(maybe, dict):
                                        seeded_answer = maybe.get("final_answer")
                                if seeded_answer:
                                    # synthesize a minimal result for trace persistence
                                    result = {
                                        "final_answer": seeded_answer,
                                        "seeded_from": seed_path,
                                    }
                                    turns_out.append(seeded_answer)
                                    seeded_used = True
                        except Exception as se:
                            # A broken or unreadable seed falls through to a normal engine run.
                            sys.stderr.write(f"[CoThinker] WARN: failed loading seed for qid={q.get('question_id')}: {se}\n")

                    if not seeded_used:
                        # The engine receives the question turns up to and including turn j.
                        sub_q = {"turns": q["turns"][: j + 1]}
                        result = engine.run(sub_q)
                        turns_out.append(result["final_answer"]) 

                # Persist per-turn trace if requested
                if trace_base_dir is not None:
                    q_dir = os.path.join(trace_base_dir, str(q["question_id"]))
                    os.makedirs(q_dir, exist_ok=True)
                    payload = {
                        "question_id": q["question_id"],
                        "turn_index": j + 1,
                        "choice_index": i,
                        "model": args.model,
                        "engine_config": getattr(engine, "cfg", None).__dict__ if getattr(engine, "cfg", None) else None,
                        # result is None only in dry-run mode.
                        "result": result if result is not None else {"final_answer": "$DUMMY$"},
                    }
                    out_path = os.path.join(q_dir, f"turn_{j+1}_choice_{i}.json")
                    try:
                        with open(out_path, "w", encoding="utf-8") as f:
                            json.dump(payload, f, ensure_ascii=False, indent=2)
                    except Exception as e:
                        # Trace failures are non-fatal; the answer itself is still recorded.
                        sys.stderr.write(f"[CoThinker] WARN: failed writing trace {out_path}: {e}\n")
        
            choices.append({"index": i, "turns": turns_out})

        # NOTE(review): total_output_tokens is always written as 0 here.
        ans = {
            "question_id": q["question_id"],
            "answer_id": shortuuid.uuid(),
            "model_id": model_display_name,
            "choices": choices,
            "tstamp": time.time(),
            "total_output_tokens": 0,
            "api_info": {
                "provider": args.api_base if args.api_base else "openai-compatible",
                "api_name": args.model,
                "api_kwargs": None,
            },
        }
        return json.dumps(ans)

    # Completed answers keyed by the question's position in `questions`, so
    # they can be written back in input order regardless of completion order.
    results_by_idx: dict[int, str] = {}
    total = len(questions)
    done = 0
    qids = []
    try:
        qids = [q.get("question_id") for q in questions]
    except Exception:
        qids = [None] * total
    workers = max(1, int(getattr(args, "parallel", 1)))
    print(f"[CoThinker] Launching outer pool: workers={workers}, tasks={total}")
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as ex:
        future_map = {ex.submit(process_one, q): idx for idx, q in enumerate(questions)}
        for fut in concurrent.futures.as_completed(future_map):
            idx = future_map[fut]
            try:
                results_by_idx[idx] = fut.result()
            except Exception as e:
                # A failed question is logged and simply left out of the output.
                sys.stderr.write(f"[CoThinker] ERROR: question index {idx} failed: {e}\n")
            finally:
                done += 1
                try:
                    pct = int((done / total) * 100) if total else 100
                except Exception:
                    pct = 0
                qid = qids[idx] if idx < len(qids) else None
                print(f"[CoThinker] Progress: {done}/{total} ({pct}%) - qid={qid}")

    # Write results in question order
    # NOTE(review): nothing is written until every task has finished, so an
    # interruption before this point discards all completed answers.
    with open(answer_file, "a", encoding="utf-8") as fout:
        for idx in range(len(questions)):
            if idx in results_by_idx:
                fout.write(results_by_idx[idx] + "\n")
    reorg_answer_file_fn(answer_file)


def main():
    """Command-line entry point: run CoThinker over LiveBench questions.

    Steps, in order:
      1. Parse CLI arguments.
      2. Load ``.env`` files from the CoThinker and LiveBench roots.
      3. Import the LiveBench helpers and export debug/reasoning options to
         the environment.
      4. Validate the selected categories and the release option.
      5. Build the CoThinker engine and the sanitized model display name.
      6. For every matching task (HuggingFace source) or ``question.jsonl``
         file (jsonl source): load, slice and filter the questions, then hand
         them to ``_run_and_write``.

    Raises:
        ValueError: For an unknown category, an unknown release, or an
            unsupported question source.
    """
    parser = argparse.ArgumentParser(description="Run CoThinker engine as a LiveBench-compatible model and write JSONL answers")
    parser.add_argument("--bench-name", type=str, default="live_bench")
    parser.add_argument("--categories", type=str, default="math,reasoning", help="Comma-separated categories to run")
    parser.add_argument("--question-source", type=str, default="huggingface", choices=["huggingface", "jsonl"])
    parser.add_argument("--livebench-release-option", type=str, default=None, help="LiveBench release (validated after importing LiveBench)")

    parser.add_argument("--model", type=str, required=True, help="Underlying chat model id (OpenAI-compatible)")
    parser.add_argument("--api-base", type=str, default=None)
    parser.add_argument("--api-key", type=str, default=None)
    parser.add_argument("--num-choices", type=int, default=1)
    # Observability & reasoning controls
    parser.add_argument("--debug", action="store_true", help="Enable CoThinker debug logs with agent previews")
    parser.add_argument("--preview-chars", type=int, default=160, help="Chars to preview per agent output in debug logs")
    parser.add_argument("--or-reason-budget", type=int, default=None, help="OpenRouter reasoning max tokens budget")
    parser.add_argument("--or-reason-exclude", action="store_true", help="Exclude reasoning text from assistant content (defaults to include)")
    parser.add_argument("--question-begin", type=int, default=None)
    parser.add_argument("--question-end", type=int, default=None)
    parser.add_argument("--question-id", type=str, nargs="+", default=None)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument("--model-display-name", type=str, default=None)
    parser.add_argument("--naming-template", type=str, default="cothinker", choices=["cothinker", "compact"], help="model_id naming template when --model-display-name is not set")
    parser.add_argument("--parallel", type=int, default=20, help="Number of questions to run in parallel")
    parser.add_argument("--save-trace", action="store_true", help="Persist per-turn traces under a _traces directory")
    parser.add_argument("--trace-dir", type=str, default=None, help="Optional base directory for traces; if relative, resolved against LiveBench root")
    # Seeding options: reuse turn-1 answers from an existing style-on+tms trace directory
    parser.add_argument("--seed-initial-from-style-tms", action="store_true", help="Reuse turn-1 final_answer from an existing style-on+tms traces per task")
    parser.add_argument("--seed-trace-root", type=str, default=None, help="Trace directory root for baseline run; if relative, resolved under each task dir (e.g., _traces/<baseline_model_display_name>)")
    parser.add_argument("--seed-choices-index", type=int, default=0, help="Choice index to pick from baseline seeds (default 0)")

    # CoThinker config
    parser.add_argument("--num-agents", type=int, default=6)
    parser.add_argument("--num-rounds", type=int, default=3)
    parser.add_argument("--enable-style-generator", action="store_true")
    parser.add_argument("--summarizer", type=str, default="tms", choices=["tms", "individual", "none"])
    parser.add_argument("--max-tokens-style", type=int, default=8192)
    parser.add_argument("--max-tokens-turn", type=int, default=8192)
    parser.add_argument("--max-tokens-summarize", type=int, default=8192)
    parser.add_argument("--max-tokens-synth", type=int, default=8192)
    parser.add_argument("--init-temperature", type=float, default=0.25)
    parser.add_argument("--followup-temperature", type=float, default=0.25)
    parser.add_argument("--small-world", action="store_true")
    parser.add_argument("--ring-k", type=int, default=2)
    parser.add_argument("--num-references", type=int, default=3)
    parser.add_argument("--rewiring-prob", type=float, default=0.3)
    parser.add_argument("--reference-mode", type=str, default="neighbor-embed", choices=["none", "neighbor-embed", "neighbor", "global-embed", "global"])
    parser.add_argument("--embed-model", type=str, default="all-MiniLM-L6-v2")
    parser.add_argument("--embed-device", type=str, default=None)
    parser.add_argument("--dry-run", action="store_true", help="Write dummy outputs for path verification without calling the model")

    args = parser.parse_args()

    # Load .env from CoThinker and LiveBench roots if present.
    # Existing environment variables win; the CoThinker file is read first.
    _load_dotenv_if_exists(os.path.join(REPO_ROOT, ".env"))
    _load_dotenv_if_exists(os.path.join(LIVEBENCH_ROOT, ".env"))

    # Lazy import LiveBench after parsing basic args
    from livebench.common import (
        LIVE_BENCH_DATA_SUPER_PATH,
        LIVE_BENCH_CATEGORIES,
        LIVE_BENCH_RELEASES,
        reorg_answer_file,
        get_categories_tasks,
        load_questions,
        load_questions_jsonl,
        filter_questions,
    )

    # Wire CoThinker debug/preview and OpenRouter reasoning controls via env
    if args.debug:
        os.environ["COTHINKER_DEBUG"] = "1"
    if args.preview_chars is not None:
        os.environ["COTHINKER_PREVIEW_CHARS"] = str(args.preview_chars)
    if args.or_reason_budget is not None:
        os.environ["COTHINKER_OR_REASON_BUDGET"] = str(args.or_reason_budget)
    # Default to not excluding reasoning text; a pre-existing env value is kept
    # unless --or-reason-exclude is passed.
    os.environ["COTHINKER_OR_REASON_EXCLUDE"] = "1" if args.or_reason_exclude else os.environ.get("COTHINKER_OR_REASON_EXCLUDE", "0")

    selected_categories = [c.strip() for c in args.categories.split(",") if c.strip()]
    for c in selected_categories:
        if c not in LIVE_BENCH_CATEGORIES:
            raise ValueError(f"Unknown category {c}. Valid: {', '.join(LIVE_BENCH_CATEGORIES)}")

    cfg = CoThinkerConfig(
        num_agents=args.num_agents,
        num_rounds=args.num_rounds,
        enable_style_generator=args.enable_style_generator,
        summarizer=args.summarizer,
        max_tokens_style=args.max_tokens_style,
        max_tokens_turn=args.max_tokens_turn,
        max_tokens_summarize=args.max_tokens_summarize,
        max_tokens_synth=args.max_tokens_synth,
        init_temperature=args.init_temperature,
        followup_temperature=args.followup_temperature,
        small_world=args.small_world,
        ring_k=args.ring_k,
        num_references=args.num_references,
        rewiring_prob=args.rewiring_prob,
        reference_mode=args.reference_mode,
        embed_model=args.embed_model,
        embed_device=args.embed_device,
    )
    templates = load_templates()
    registry = PromptRegistry(templates)
    client_factory = lambda: OpenAIStyleClient(model=args.model, api_base=args.api_base, api_key=args.api_key)
    engine = CoThinkerEngine(client_factory=client_factory, prompt_registry=registry, config=cfg)

    # The display name doubles as the answer file name, so it is lower-cased
    # and stripped of filename-unsafe characters.
    model_display_name_raw = (args.model_display_name or make_model_display_name(args.model, cfg, args.naming_template)).lower()
    model_display_name = sanitize_model_id(model_display_name_raw)

    # Default to the latest release; questions from all releases up to and
    # including the chosen one are eligible.
    if args.livebench_release_option is None:
        args.livebench_release_option = max(LIVE_BENCH_RELEASES)
    if args.livebench_release_option not in LIVE_BENCH_RELEASES:
        raise ValueError(f"Bad release {args.livebench_release_option}.")
    release_set = {r for r in LIVE_BENCH_RELEASES if r <= args.livebench_release_option}

    if args.question_source == "huggingface":
        categories, tasks = get_categories_tasks(args.bench_name)
        for category_name, task_names in tasks.items():
            if category_name not in selected_categories:
                continue
            for task_name in task_names:
                questions = load_questions(
                    categories[category_name],
                    livebench_releases=release_set,
                    livebench_release=args.livebench_release_option,
                    task_name=task_name,
                    question_ids=args.question_id,
                )
                questions = questions[args.question_begin:args.question_end]

                task_full_name = f"{LIVE_BENCH_DATA_SUPER_PATH}/{category_name}/{task_name}"
                answer_file = os.path.join(LIVEBENCH_ROOT, "data", task_full_name, "model_answer", f"{model_display_name}.jsonl")

                questions = filter_questions(questions, answer_file, resume=args.resume, retry_failures=False)
                if len(questions) == 0:
                    print(f"No questions to run for {task_full_name}")
                    continue
                _run_and_write(questions, answer_file, model_display_name, engine, args, reorg_answer_file)

    elif args.question_source == "jsonl":
        base = os.path.join(LIVEBENCH_ROOT, "data", args.bench_name)
        explicit = os.path.join(base, "question.jsonl")
        if os.path.exists(explicit):
            files = [explicit]
        else:
            import glob
            files = glob.glob(os.path.join(base, "**", "question.jsonl"), recursive=True)

        for question_file in files:
            # Path of the question file's directory relative to <LiveBench>/data.
            bench_name = os.path.dirname(question_file).replace(os.path.join(LIVEBENCH_ROOT, "data") + os.sep, "")
            # The category is the second path component. Normalise the native
            # separator to "/" first: bench_name is built with os.sep, so a
            # plain split("/") finds no components where os.sep is not "/",
            # which silently disabled the --categories filter.
            bench_parts = bench_name.replace(os.sep, "/").split("/")
            _category = bench_parts[1] if len(bench_parts) > 1 else None
            if _category and _category not in selected_categories:
                continue
            questions = load_questions_jsonl(
                question_file,
                livebench_releases=release_set,
                livebench_release=args.livebench_release_option,
                question_ids=args.question_id,
            )
            questions = questions[args.question_begin:args.question_end]

            answer_file = os.path.join(LIVEBENCH_ROOT, "data", bench_name, "model_answer", f"{model_display_name}.jsonl")
            questions = filter_questions(questions, answer_file, resume=args.resume, retry_failures=False)
            if len(questions) == 0:
                print(f"No questions to run for {bench_name}")
                continue
            _run_and_write(questions, answer_file, model_display_name, engine, args, reorg_answer_file)
    else:
        raise ValueError(f"Bad question source {args.question_source}.")


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()
